{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Emotion classification multiclass example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook demonstrates how to use the `Partition` explainer for a multiclass text classification scenario. Once the SHAP values are computed for a set of sentences, we visualize the feature attributions toward each individual class. The text classification model is BERT fine-tuned on an emotion dataset (`nateraw/bert-base-uncased-emotion`), which classifies a sentence into one of six classes: joy, sadness, anger, fear, love, and surprise." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using custom data configuration default\n", "Reusing dataset emotion (/home/slundberg/.cache/huggingface/datasets/emotion/default/0.0.0/aa34462255cd487d04be8387a2d572588f6ceee23f784f37365aa714afeb8fe6)\n" ] } ], "source": [ "import datasets\n", "import pandas as pd\n", "import transformers\n", "\n", "import shap\n", "\n", "# load the emotion dataset\n", "dataset = datasets.load_dataset(\"emotion\", split=\"train\")\n", "data = pd.DataFrame({\"text\": dataset[\"text\"], \"emotion\": dataset[\"label\"]})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Build a transformers pipeline\n", "\n", "Note that we set `return_all_scores=True` for the pipeline so that we can observe the model's scores for all classes, not just the top prediction."
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# load the model and tokenizer\n", "tokenizer = transformers.AutoTokenizer.from_pretrained(\"nateraw/bert-base-uncased-emotion\", use_fast=True)\n", "model = transformers.AutoModelForSequenceClassification.from_pretrained(\"nateraw/bert-base-uncased-emotion\").cuda()\n", "\n", "# build a pipeline object to do predictions\n", "pred = transformers.pipeline(\n", "    \"text-classification\",\n", "    model=model,\n", "    tokenizer=tokenizer,\n", "    device=0,\n", "    return_all_scores=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create an explainer for the pipeline\n", "\n", "A transformers pipeline object can be passed directly to `shap.Explainer`, which then wraps the pipeline model as a `shap.models.TransformersPipeline` model and the pipeline tokenizer as a `shap.maskers.Text` masker." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "explainer = shap.Explainer(pred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compute SHAP values\n", "\n", "Explainers have the same method signature as the models they explain, so we simply pass the sentences whose classifications we want to explain (here, the first three sentences of the dataset)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "shap_values = explainer(data[\"text\"][:3])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualize the impact on all the output classes\n", "\n", "In the plots below, hovering your mouse over an output class shows the explanation for that class. Clicking an output class name keeps that class in focus until you click another class.\n", "\n", "The base value is the model output when the entire input text is masked, while $f_{\\text{output class}}(\\text{inputs})$ is the model output for the full original input. 
The SHAP values are additive: the base value plus the SHAP values of all tokens equals the model output for the full input, so each token's SHAP value measures how much unmasking that token moves the output from the base value (where the entire input is masked) toward the final prediction." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "\n", "
[0]" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "[Figure: SHAP text plot for the first sentence. Output classes: sadness, joy, love, anger, fear, surprise. For the sadness class, base value = 0.131672 and f_sadness(inputs) = 0.996408. Token SHAP values: i 0.003, didn 0.009, t 0.001, feel -0.004, humiliated 0.855.]